{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "ImageNet splitter for milking CowMask OSS - sharded",
      "provenance": [
        {
          "file_id": "1O0tN4minvv8uQILKcB6gd5_1xef-AVzA",
          "timestamp": 1579254865259
        },
        {
          "file_id": "1IWougL8z9OR23Zya_rK9yJJ7YRAuAN7-",
          "timestamp": 1572543309030
        },
        {
          "file_id": "1U362onxwK1Rap22Nuf-CY7wciX0B6RXy",
          "timestamp": 1572430150432
        },
        {
          "file_id": "1mun4l-faGn_VS0bgiWNu_LWRvu6WX32T",
          "timestamp": 1572354525769
        },
        {
          "file_id": "1rYqxyRlKW1AeLyMSnFE-Mv332_5GmtMc",
          "timestamp": 1572262833111
        }
      ],
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/dm_python:dm_notebook3",
        "kind": "private"
      }
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8SNj8n5PpNMG",
        "colab_type": "text"
      },
      "source": [
        "# Generate splits of ImageNet and save in sharded TFRECORD files\n",
        "\n",
        "### Licensed under the Apache License, Version 2.0"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BDkLxCCCZD5w",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import os\n",
        "os.environ['UNITTEST_ON_FORGE'] = '1'\n",
        "\n",
        "import string\n",
        "import random\n",
        "import pickle\n",
        "import zipfile\n",
        "import io\n",
        "import itertools\n",
        "\n",
        "import time\n",
        "import tensorflow as tf\n",
        "import numpy as np\n",
        "from sklearn.model_selection import StratifiedShuffleSplit\n",
        "from matplotlib import pyplot as plt\n",
        "import tqdm\n",
        "import tensorflow_datasets as tfds\n",
        "\n",
        "tf.enable_eager_execution()\n",
        "\n",
        "\n",
        "import distutils\n",
        "if distutils.version.LooseVersion(tf.__version__) < '1.14':\n",
        "    raise Exception('This notebook is compatible with TensorFlow 1.14 or higher, for TensorFlow 1.13 or lower please use the previous version at https://github.com/tensorflow/tpu/blob/r1.13/tools/colab/fashion_mnist.ipynb')\n",
        "\n",
        "  \n",
        "print('Tensorflow version {}'.format(tf.__version__))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZF5B2Jka7i11",
        "colab_type": "text"
      },
      "source": [
        "## Path to ImageNet on Placer"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "aufrmBh0fIyj",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# EDIT THESE\n",
        "IMAGENET_TFRECORDS_SOURCE_PATH = r'<source path in here>'\n",
        "OUT_SUBSET_SHARDS_PATH = r'<destination path in here>'\n",
        "IMAGENET_SIZE = 1281167\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tUKaAvlT7pT1",
        "colab_type": "text"
      },
      "source": [
        "## Description of features in ImageNet `tfrecord` files"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "IXUGZb_q7oo0",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "feature_description = {\n",
        "    'label': tf.io.FixedLenFeature([], tf.int64, default_value=0),\n",
        "    'image': tf.io.FixedLenFeature([], tf.string),\n",
        "    'file_name': tf.io.FixedLenFeature([], tf.string),\n",
        "}"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TKoB9dzE7t7V",
        "colab_type": "text"
      },
      "source": [
        "## Load `tfrecord` dataset"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2hbW20x7o01B",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "train_files = [f for f in tf.io.gfile.listdir(IMAGENET_TFRECORDS_SOURCE_PATH) if f.startswith('imagenet2012-train.tfrecord')]\n",
        "train_paths = [IMAGENET_TFRECORDS_SOURCE_PATH + f for f in train_files]\n",
        "ds = tf.data.TFRecordDataset(train_paths)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_g7o-BcS710M",
        "colab_type": "text"
      },
      "source": [
        "## Get ground truth labels and filenames for all samples in training set"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qgYq8K-M1Hai",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def get_labels(ds, N):\n",
        "  it = ds.prefetch(tf.data.experimental.AUTOTUNE).make_one_shot_iterator()\n",
        "\n",
        "  all_ys = []\n",
        "  all_fns = []\n",
        "  for _ in tqdm.tqdm(range(N)):\n",
        "    sample = it.next()\n",
        "    all_ys.append(sample['label'])\n",
        "    all_fns.append(sample[SAMPLE_FILENAME_ATTR].numpy())\n",
        "\n",
        "  return np.array(all_ys), all_fns\n",
        "\n",
        "all_y, all_filenames = get_labels(tfds.load(name='imagenet2012', split=tfds.Split.TRAIN), IMAGENET_SIZE)\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KvZXyXGH8eue",
        "colab_type": "text"
      },
      "source": [
        "## Define functions for getting ImageNet subsets"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6uaH1O9bBSpY",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def random_name():\n",
        "  allchar = string.ascii_letters\n",
        "  name = \"\".join([random.choice(allchar) for x in range(8)])\n",
        "  return name\n",
        "\n",
        "def dataset_subset_by_file_name(in_ds, subset_filenames):\n",
        "  kv_init = tf.lookup.KeyValueTensorInitializer(np.array(subset_filenames), np.ones((len(subset_filenames),), dtype=int),\n",
        "                                                key_dtype=tf.string, value_dtype=tf.int64, name=random_name())\n",
        "  ht = tf.lookup.StaticHashTable(kv_init, 0, name=random_name())\n",
        "\n",
        "  def pred_fn(x):\n",
        "    f = tf.io.parse_single_example(x, feature_description)\n",
        "    return tf.equal(ht.lookup(f['file_name']), 1)\n",
        "\n",
        "  return in_ds.filter(pred_fn)\n",
        "\n",
        "\n",
        "def imagenet_subset(n_samples, seed):\n",
        "  splitter = StratifiedShuffleSplit(1, test_size=n_samples, random_state=seed)\n",
        "  _, ndx = next(splitter.split(data_index['y'], data_index['y']))\n",
        "\n",
        "  sub_fn = [all_filenames[int(i)] for i in ndx]\n",
        "\n",
        "  return dataset_subset_by_file_name(ds, sub_fn)\n",
        "\n",
        "\n",
        "\n",
        "def write_imagenet_subset_by_fn_sharded(out_dir, name, filenames, num_shards, group='brain-ams'):\n",
        "  # out_path is a directory name\n",
        "  out_path = os.path.join(out_dir, name)\n",
        "  if tf.io.gfile.exists(out_path):\n",
        "    print('Skipping already existing {}'.format(out_path))\n",
        "    return\n",
        "  print('Generating {} ...'.format(out_path))\n",
        "  tf.io.gfile.mkdir(out_path)\n",
        "  t1 = time.time()\n",
        "  sub_ds = dataset_subset_by_file_name(ds, filenames)\n",
        "\n",
        "  shard_base_path = os.path.join(out_path, '{}.tfrecord-'.format(name))\n",
        "  def reduce_func(key, dataset):\n",
        "    filename = tf.strings.join([shard_base_path, tf.strings.as_string(key)])\n",
        "    writer = tf.data.experimental.TFRecordWriter(filename)\n",
        "    writer.write(dataset.map(lambda _, x: x))\n",
        "    return tf.data.Dataset.from_tensors(filename)\n",
        "\n",
        "  write_ds = sub_ds.enumerate()\n",
        "  write_ds = write_ds.apply(tf.data.experimental.group_by_window(\n",
        "    lambda i, _: i % num_shards, reduce_func, tf.int64.max\n",
        "  ))\n",
        "  for x in write_ds:\n",
        "    pass\n",
        "\n",
        "  t2 = time.time()\n",
        "  print('Built subset {} in {:.2f}s'.format(\n",
        "      name, t2 - t1\n",
        "  ))\n",
        "\n",
        "\n",
        "def write_imagenet_subset_sharded(out_dir, name, ds_filenames, ds_y, n_samples, seed, num_shards, group='brain-ams'):\n",
        "  splitter = StratifiedShuffleSplit(1, test_size=n_samples, random_state=seed)\n",
        "  _, ndx = next(splitter.split(ds_y, ds_y))\n",
        "\n",
        "  sub_fn = [ds_filenames[int(i)] for i in ndx]\n",
        "\n",
        "  write_imagenet_subset_by_fn_sharded(out_dir, name, sub_fn, num_shards, group=group)\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lXxmKqER9JaD",
        "colab_type": "text"
      },
      "source": [
        "## Build our subsets"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "A8T0XSue7rkW",
        "colab_type": "text"
      },
      "source": [
        "## Split into train and val"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "fZjgHHZVpMSK",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "n_train_val = len(all_filenames)\n",
        "N_VAL = 50000\n",
        "VAL_SEED = 131\n",
        "\n",
        "# Validation set\n",
        "trainval_splitter = StratifiedShuffleSplit(1, test_size=N_VAL, random_state=VAL_SEED)\n",
        "train_ndx, val_ndx = next(trainval_splitter.split(data_index['y'], data_index['y']))\n",
        "\n",
        "train_fn = [all_filenames[int(i)] for i in train_ndx]\n",
        "train_y = all_y[train_ndx]\n",
        "val_fn = [all_filenames[int(i)] for i in val_ndx]\n",
        "val_y = all_y[val_ndx]\n",
        "\n",
        "print('Split train-val set of {} into {} train and {} val'.format(\n",
        "    n_train_val, len(train_fn), len(val_fn)\n",
        "))\n",
        "\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "wyjgQN_0vZgB",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "split_path = os.path.join(OUT_SUBSET_SHARDS_PATH, 'imagenet_tv{}s{}_split.pkl'.format(N_VAL, VAL_SEED))\n",
        "with tf.io.gfile.GFile(split_path, mode='wb') as f_split:\n",
        "  split_data = dict(train_fn=train_fn, train_y=train_y, val_fn=val_fn, val_y=val_y)\n",
        "  pickle.dump(split_data, f_split)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BKGFLWA760DJ",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "\n",
        "\n",
        "# Val\n",
        "write_imagenet_subset_by_fn_sharded(OUT_SUBSET_SHARDS_PATH, 'imagenet_tv{}s{}_val'.format(N_VAL, VAL_SEED),\n",
        "                                    val_fn, num_shards=256)\n",
        "\n",
        "\n",
        "# 1% subsets\n",
        "write_imagenet_subset_sharded(OUT_SUBSET_SHARDS_PATH, 'imagenet_tv{}s{}_{}_seed{}'.format(N_VAL, VAL_SEED, len(train_fn)//100, 12345),\n",
        "                              train_fn, train_y, len(train_fn)//100, 12345, num_shards=256)\n",
        "write_imagenet_subset_sharded(OUT_SUBSET_SHARDS_PATH, 'imagenet_tv{}s{}_{}_seed{}'.format(N_VAL, VAL_SEED, len(train_fn)//100, 23456),\n",
        "                              train_fn, train_y, len(train_fn)//100, 23456, num_shards=256)\n",
        "write_imagenet_subset_sharded(OUT_SUBSET_SHARDS_PATH, 'imagenet_tv{}s{}_{}_seed{}'.format(N_VAL, VAL_SEED, len(train_fn)//100, 34567),\n",
        "                              train_fn, train_y, len(train_fn)//100, 34567, num_shards=256)\n",
        "write_imagenet_subset_sharded(OUT_SUBSET_SHARDS_PATH, 'imagenet_tv{}s{}_{}_seed{}'.format(N_VAL, VAL_SEED, len(train_fn)//100, 45678),\n",
        "                              train_fn, train_y, len(train_fn)//100, 45678, num_shards=256)\n",
        "write_imagenet_subset_sharded(OUT_SUBSET_SHARDS_PATH, 'imagenet_tv{}s{}_{}_seed{}'.format(N_VAL, VAL_SEED, len(train_fn)//100, 56789),\n",
        "                              train_fn, train_y, len(train_fn)//100, 56789, num_shards=256)\n",
        "\n",
        "\n",
        "# 10% subsets\n",
        "write_imagenet_subset_sharded(OUT_SUBSET_SHARDS_PATH, 'imagenet_tv{}s{}_{}_seed{}'.format(N_VAL, VAL_SEED, len(train_fn)//10, 12345),\n",
        "                              train_fn, train_y, len(train_fn)//10, 12345, num_shards=256)\n",
        "write_imagenet_subset_sharded(OUT_SUBSET_SHARDS_PATH, 'imagenet_tv{}s{}_{}_seed{}'.format(N_VAL, VAL_SEED, len(train_fn)//10, 23456),\n",
        "                              train_fn, train_y, len(train_fn)//10, 23456, num_shards=256)\n",
        "write_imagenet_subset_sharded(OUT_SUBSET_SHARDS_PATH, 'imagenet_tv{}s{}_{}_seed{}'.format(N_VAL, VAL_SEED, len(train_fn)//10, 34567),\n",
        "                              train_fn, train_y, len(train_fn)//10, 34567, num_shards=256)\n",
        "write_imagenet_subset_sharded(OUT_SUBSET_SHARDS_PATH, 'imagenet_tv{}s{}_{}_seed{}'.format(N_VAL, VAL_SEED, len(train_fn)//10, 45678),\n",
        "                              train_fn, train_y, len(train_fn)//10, 45678, num_shards=256)\n",
        "write_imagenet_subset_sharded(OUT_SUBSET_SHARDS_PATH, 'imagenet_tv{}s{}_{}_seed{}'.format(N_VAL, VAL_SEED, len(train_fn)//10, 56789),\n",
        "                              train_fn, train_y, len(train_fn)//10, 56789, num_shards=256)\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cp8pjcHx7kq-",
        "colab_type": "text"
      },
      "source": [
        "## Write train subsets (no validation split; use ImageNet validation as evaluation set)\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XB9J47MV7Uea",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# 1% subsets\n",
        "write_imagenet_subset_sharded(OUT_SUBSET_SHARDS_PATH, 'imagenet_{}_seed{}'.format(IMAGENET_SIZE//100, 12345),\n",
        "                              all_filenames, len(all_filenames)//100, 12345, num_shards=256)\n",
        "write_imagenet_subset_sharded(OUT_SUBSET_SHARDS_PATH, 'imagenet_{}_seed{}'.format(IMAGENET_SIZE//100, 23456),\n",
        "                              all_filenames, len(all_filenames)//100, 23456, num_shards=256)\n",
        "write_imagenet_subset_sharded(OUT_SUBSET_SHARDS_PATH, 'imagenet_{}_seed{}'.format(IMAGENET_SIZE//100, 34567),\n",
        "                              all_filenames, len(all_filenames)//100, 34567, num_shards=256)\n",
        "write_imagenet_subset_sharded(OUT_SUBSET_SHARDS_PATH, 'imagenet_{}_seed{}'.format(IMAGENET_SIZE//100, 45678),\n",
        "                              all_filenames, len(all_filenames)//100, 45678, num_shards=256)\n",
        "write_imagenet_subset_sharded(OUT_SUBSET_SHARDS_PATH, 'imagenet_{}_seed{}'.format(IMAGENET_SIZE//100, 56789),\n",
        "                              all_filenames, len(all_filenames)//100, 56789, num_shards=256)\n",
        "\n",
        "\n",
        "# 10% subsets\n",
        "write_imagenet_subset_sharded(OUT_SUBSET_SHARDS_PATH, 'imagenet_{}_seed{}'.format(IMAGENET_SIZE//10, 12345),\n",
        "                              all_filenames, len(all_filenames)//10, 12345, num_shards=256)\n",
        "write_imagenet_subset_sharded(OUT_SUBSET_SHARDS_PATH, 'imagenet_{}_seed{}'.format(IMAGENET_SIZE//10, 23456),\n",
        "                              all_filenames, len(all_filenames)//10, 23456, num_shards=256)\n",
        "write_imagenet_subset_sharded(OUT_SUBSET_SHARDS_PATH, 'imagenet_{}_seed{}'.format(IMAGENET_SIZE//10, 34567),\n",
        "                              all_filenames, len(all_filenames)//10, 34567, num_shards=256)\n",
        "write_imagenet_subset_sharded(OUT_SUBSET_SHARDS_PATH, 'imagenet_{}_seed{}'.format(IMAGENET_SIZE//10, 45678),\n",
        "                              all_filenames, len(all_filenames)//10, 45678, num_shards=256)\n",
        "write_imagenet_subset_sharded(OUT_SUBSET_SHARDS_PATH, 'imagenet_{}_seed{}'.format(IMAGENET_SIZE//10, 56789),\n",
        "                              all_filenames, len(all_filenames)//10, 56789, num_shards=256)\n"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}